{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow Lattice estimators\n",
    "In this tutorial, we will TensorFlow Lattice estimators.\n",
    "The more detailed version of this notebook can be found in\n",
    "https://github.com/tensorflow/lattice/blob/master/g3doc/tutorial/index.md"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import libraries\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import tensorflow_lattice as tfl\n",
    "import tempfile\n",
    "import urllib\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_if_not_exists(train_data, test_data):\n",
    "    \"\"\"Maybe downloads training data and returns train and test file names.\"\"\"\n",
    "    train_file_name = train_data\n",
    "    if not os.path.exists(train_file_name):\n",
    "        train_file = tempfile.NamedTemporaryFile(delete=False)\n",
    "        urllib.urlretrieve(\n",
    "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n",
    "            train_file.name)  # pylint: disable=line-too-long\n",
    "        train_file_name = train_file.name\n",
    "        train_file.close()\n",
    "        print(\"Training data is downloaded to %s\" % train_file_name)\n",
    "\n",
    "    test_file_name = test_data\n",
    "    if not os.path.exists(test_file_name):\n",
    "        test_file = tempfile.NamedTemporaryFile(delete=False)\n",
    "        urllib.urlretrieve(\n",
    "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n",
    "            test_file.name)  # pylint: disable=line-too-long\n",
    "        test_file_name = test_file.name\n",
    "        test_file.close()\n",
    "        print(\"Test data is downloaded to %s\"% test_file_name)\n",
    "    \n",
    "    return (train_file_name, test_file_name)\n",
    "\n",
    "# Specify the dataset\n",
    "(TRAIN_DATA, TEST_DATA) = download_if_not_exists(\"/tmp/tfl-data/adult.data\",\n",
    "                                                 \"/tmp/tfl-data/adult.test\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CSV_COLUMNS = [\n",
    "    \"age\", \"workclass\", \"fnlwgt\", \"education\", \"education_num\",\n",
    "    \"marital_status\", \"occupation\", \"relationship\", \"race\", \"gender\",\n",
    "    \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"native_country\",\n",
    "    \"income_bracket\"\n",
    "]\n",
    "\n",
    "def get_input_fn(file_path, batch_size, num_epochs, shuffle):\n",
    "    df_data = pd.read_csv(\n",
    "        tf.gfile.Open(file_path),\n",
    "        names=CSV_COLUMNS,\n",
    "        skipinitialspace=True,\n",
    "        engine=\"python\",\n",
    "        skiprows=1)\n",
    "    # Drop missing for the time being.\n",
    "    df_data = df_data.dropna(how=\"any\", axis=0)\n",
    "    labels = df_data[\"income_bracket\"].apply(lambda x: \">50K\" in x).astype(int)\n",
    "    return tf.estimator.inputs.pandas_input_fn(\n",
    "        x=df_data,\n",
    "        y=labels,\n",
    "        batch_size=batch_size,\n",
    "        num_epochs=num_epochs,\n",
    "        shuffle=shuffle,\n",
    "        num_threads=5)\n",
    "\n",
    "def get_train_input_fn(batch_size, num_epochs=1, shuffle=False):\n",
    "    train_data = TRAIN_DATA\n",
    "    return get_input_fn(train_data, batch_size, num_epochs, shuffle)\n",
    "\n",
    "\n",
    "def densify(fc, make_dense):\n",
    "    if not make_dense:\n",
    "        return fc\n",
    "    return tf.feature_column.embedding_column(fc, 4)\n",
    "\n",
    "\n",
    "def get_feature_columns(make_dense=False):\n",
    "    # Categorical features.\n",
    "    gender = densify(\n",
    "        tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "            \"gender\", [\"Female\", \"Male\"]), make_dense)\n",
    "    education = densify(\n",
    "        tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "            \"education\", [\n",
    "                \"Bachelors\", \"HS-grad\", \"11th\", \"Masters\", \"9th\", \"Some-college\",\n",
    "                \"Assoc-acdm\", \"Assoc-voc\", \"7th-8th\", \"Doctorate\", \"Prof-school\",\n",
    "                \"5th-6th\", \"10th\", \"1st-4th\", \"Preschool\", \"12th\"\n",
    "            ]), make_dense)\n",
    "    marital_status = densify(\n",
    "        tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "            \"marital_status\", [\n",
    "                \"Married-civ-spouse\", \"Divorced\", \"Married-spouse-absent\",\n",
    "                \"Never-married\", \"Separated\", \"Married-AF-spouse\", \"Widowed\"\n",
    "            ]), make_dense)\n",
    "    relationship = densify(\n",
    "        tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "            \"relationship\", [\n",
    "                \"Husband\", \"Not-in-family\", \"Wife\", \"Own-child\", \"Unmarried\",\n",
    "                \"Other-relative\"\n",
    "            ]), make_dense)\n",
    "    workclass = densify(\n",
    "        tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "            \"workclass\", [\n",
    "                \"Self-emp-not-inc\", \"Private\", \"State-gov\", \"Federal-gov\",\n",
    "                \"Local-gov\", \"?\", \"Self-emp-inc\", \"Without-pay\", \"Never-worked\"\n",
    "            ]), make_dense)\n",
    "\n",
    "    # To show an example of hashing:\n",
    "    occupation = densify(\n",
    "        tf.feature_column.categorical_column_with_hash_bucket(\n",
    "            \"occupation\", hash_bucket_size=1000), make_dense)\n",
    "    native_country = densify(\n",
    "        tf.feature_column.categorical_column_with_hash_bucket(\n",
    "            \"native_country\", hash_bucket_size=1000), make_dense)\n",
    "\n",
    "    # Continuous base columns.\n",
    "    age = tf.feature_column.numeric_column(\"age\")\n",
    "    education_num = tf.feature_column.numeric_column(\"education_num\")\n",
    "    capital_gain = tf.feature_column.numeric_column(\"capital_gain\")\n",
    "    capital_loss = tf.feature_column.numeric_column(\"capital_loss\")\n",
    "    hours_per_week = tf.feature_column.numeric_column(\"hours_per_week\")\n",
    "    \n",
    "    return [\n",
    "        age,\n",
    "        education_num,\n",
    "        capital_gain,\n",
    "        capital_loss,\n",
    "        hours_per_week,\n",
    "        gender,\n",
    "        education,\n",
    "        marital_status,\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a histogram\n",
    "This information will be used to initialize the calibrator input keypoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quantiles_dir = tempfile.mkdtemp()\n",
    "\n",
    "def create_histogram(quantiles_dir):\n",
    "    input_fn = get_train_input_fn(batch_size=10000, num_epochs=1, shuffle=False)\n",
    "    tfl.save_quantiles_for_keypoints(\n",
    "        input_fn=input_fn,\n",
    "        save_dir=quantiles_dir,\n",
    "        feature_columns=get_feature_columns(make_dense=False),\n",
    "        num_steps=10)\n",
    "    \n",
    "create_histogram(quantiles_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Estimator!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_linear_estimator(model_dir, feature_columns, learning_rate):\n",
    "    \"\"\"Build linear estimator.\"\"\"\n",
    "    feature_names = [fc.name for fc in feature_columns]\n",
    "    hparams = tfl.CalibratedLinearHParams(\n",
    "        feature_names=feature_names,\n",
    "        learning_rate=learning_rate,\n",
    "        num_keypoints=20)\n",
    "\n",
    "    m = tfl.calibrated_linear_classifier(\n",
    "        model_dir=model_dir,\n",
    "        quantiles_dir=quantiles_dir,\n",
    "        feature_columns=feature_columns,\n",
    "        hparams=hparams)\n",
    "    return m\n",
    "\n",
    "def _build_rtl_estimator(model_dir, feature_columns, learning_rate):\n",
    "    \"\"\"Build rtl estimator.\"\"\"\n",
    "    feature_names = [fc.name for fc in feature_columns]\n",
    "    # Create 100 number of 2 x 2 x 2 x 2 lattices.\n",
    "    hparams = tfl.CalibratedRtlHParams(\n",
    "        feature_names=feature_names,\n",
    "        learning_rate=learning_rate,\n",
    "        lattice_rank=4,\n",
    "        num_lattices=100,\n",
    "        num_keypoints=20)\n",
    "    \n",
    "    m = tfl.calibrated_rtl_classifier(\n",
    "        model_dir=model_dir,\n",
    "        quantiles_dir=quantiles_dir,\n",
    "        feature_columns=feature_columns,\n",
    "        hparams=hparams)\n",
    "    return m\n",
    "\n",
    "def build_estimator(model_dir, learning_rate, model_type='rtl'):\n",
    "    \"\"\"Build an estimator.\"\"\"\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        tf.gfile.MkDir(model_dir)\n",
    "    feature_columns = get_feature_columns(make_dense=False)\n",
    "    if model_type == 'rtl':\n",
    "        return _build_rtl_estimator(model_dir, feature_columns, learning_rate)\n",
    "    elif model_type == 'linear':\n",
    "        return _build_linear_estimator(model_dir, feature_columns, learning_rate)\n",
    "    else:\n",
    "        raise ValueError('unsupported model type: %s' % model_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train linear model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = tempfile.mkdtemp()\n",
    "learning_rate = 0.01\n",
    "batch_size = 100\n",
    "\n",
    "m = build_estimator(model_dir, learning_rate, model_type='linear')\n",
    "m.train(input_fn=get_train_input_fn(batch_size=100, num_epochs=1, shuffle=True))\n",
    "\n",
    "print('=====Training set=====')\n",
    "results = m.evaluate(input_fn=get_train_input_fn(batch_size=batch_size))\n",
    "for key in sorted(results):\n",
    "    print('%s: %s' % (key, results[key]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train RTL model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = tempfile.mkdtemp()\n",
    "learning_rate = 0.01\n",
    "batch_size = 100\n",
    "\n",
    "m = build_estimator(model_dir, learning_rate, model_type='rtl')\n",
    "m.train(input_fn=get_train_input_fn(batch_size=100, num_epochs=1, shuffle=True))\n",
    "\n",
    "print('=====Training set=====')\n",
    "results = m.evaluate(input_fn=get_train_input_fn(batch_size=batch_size))\n",
    "for key in sorted(results):\n",
    "    print('%s: %s' % (key, results[key]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-lattice",
   "language": "python",
   "name": "tf-lattice"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

